{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "code_folding": [],
    "collapsed": true,
    "hidden_ranges": [],
    "jupyter": {
     "outputs_hidden": true
    },
    "originalKey": "c901c723-b2f3-4f75-96c0-6555954209f5",
    "showInput": false
   },
   "source": [
    "### Upper Confidence Bound (UCB)\n",
    "\n",
    "The Upper Confidence Bound (UCB) acquisition function balances exploration and exploitation by assigning a score of $\\mu + \\sqrt{\\beta} \\cdot \\sigma$ when the posterior distribution at a point is normal with mean $\\mu$ and variance $\\sigma^2$. This \"analytic\" version is implemented in the `UpperConfidenceBound` class. The Monte Carlo version of UCB is implemented in the `qUpperConfidenceBound` class, which also allows for q-batches of size greater than one. (The derivation of q-UCB is given in Appendix A of [Wilson et al., 2017](https://arxiv.org/pdf/1712.00424.pdf).)\n",
    "\n",
    "### A scalarized version of q-UCB\n",
    "\n",
    "Suppose now that we are in a multi-output setting, where, e.g., we model the effects of a design on multiple metrics. We first show a simple extension of the q-UCB acquisition function that accepts a multi-output model and performs q-UCB on a scalarized version of the outputs, obtained as a weighted sum using a vector of weights. Implementing a new acquisition function in BoTorch is easy; one simply needs to implement the constructor and a `forward` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customOutput": null,
    "executionStartTime": 1668651267039,
    "executionStopTime": 1668651273452,
    "hidden_ranges": [],
    "jupyter": {
     "outputs_hidden": false
    },
    "originalKey": "75503e22-bf3b-49d6-87a8-cb9c741e941e",
    "requestMsgId": "28d53a4f-be29-4ec6-8529-e1bdc80b5adf"
   },
   "outputs": [
   ],
   "source": [
    "import math\n",
    "from typing import Optional\n",
    "\n",
    "import torch\n",
    "from botorch.acquisition.monte_carlo import MCAcquisitionFunction\n",
    "from botorch.models.model import Model\n",
    "from botorch.sampling.base import MCSampler\n",
    "from botorch.sampling.normal import SobolQMCNormalSampler\n",
    "from botorch.utils import average_over_ensemble_models, t_batch_mode_transform\n",
    "from torch import Tensor\n",
    "\n",
    "\n",
    "class qScalarizedUpperConfidenceBound(MCAcquisitionFunction):\n",
    "    def __init__(\n",
    "        self,\n",
    "        model: Model,\n",
    "        beta: Tensor,\n",
    "        weights: Tensor,\n",
    "        sampler: Optional[MCSampler] = None,\n",
    "    ) -> None:\n",
    "        # we use the AcquisitionFunction constructor, since that of\n",
    "        # MCAcquisitionFunction performs some validity checks that we don't want here\n",
    "        super(MCAcquisitionFunction, self).__init__(model=model)\n",
    "        if sampler is None:\n",
    "            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([512]))\n",
    "        self.sampler = sampler\n",
    "        self.register_buffer(\"beta\", torch.as_tensor(beta))\n",
    "        self.register_buffer(\"weights\", torch.as_tensor(weights))\n",
    "\n",
    "    @t_batch_mode_transform()\n",
    "    @average_over_ensemble_models\n",
    "    def forward(self, X: Tensor) -> Tensor:\n",
    "        \"\"\"Evaluate scalarized qUCB on the candidate set `X`.\n",
    "\n",
    "        Args:\n",
    "            X: A `(b) x q x d`-dim Tensor of `(b)` t-batches with `q` `d`-dim\n",
    "                design points each.\n",
    "\n",
    "        Returns:\n",
    "            Tensor: A `(b)`-dim Tensor of Upper Confidence Bound values at the\n",
    "                given design points `X`.\n",
    "        \"\"\"\n",
    "        posterior = self.model.posterior(X)\n",
    "        samples = self.get_posterior_samples(posterior)  # n x b x q x o\n",
    "        scalarized_samples = samples.matmul(self.weights)  # n x b x q\n",
    "        mean = posterior.mean  # b x q x o\n",
    "        scalarized_mean = mean.matmul(self.weights)  # b x q\n",
    "        ucb_samples = (\n",
    "            scalarized_mean\n",
    "            + math.sqrt(self.beta * math.pi / 2)\n",
    "            * (scalarized_samples - scalarized_mean).abs()\n",
    "        )\n",
    "        return ucb_samples.max(dim=-1)[0].mean(dim=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "code_folding": [],
    "hidden_ranges": [],
    "originalKey": "b1d4ee8e-69e1-4914-b113-473add84f322",
    "showInput": false
   },
   "source": [
    "Note that `qScalarizedUpperConfidenceBound` is very similar to `qUpperConfidenceBound` and only requires a few lines of new code to accommodate scalarization of multiple outputs. The `@t_batch_mode_transform` decorator ensures that the input `X` has an explicit t-batch dimension (code comments are added with shapes for clarity). The `@average_over_ensemble_models` decorator averages the acquisition values over the members of an ensemble model, such as the fully Bayesian SAAS model.\n",
    "\n",
    "See the end of this tutorial for a quick and easy way of achieving the same scalarization effect using `ScalarizedPosteriorTransform`."
   ]
  },
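  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The factor `math.sqrt(self.beta * math.pi / 2)` in `forward` can look surprising next to the analytic score $\\mu + \\sqrt{\\beta} \\cdot \\sigma$. Here is a short sanity check for `q=1`, sketched under the assumption that the scalarized posterior at a point is normal with mean $m$ and standard deviation $s$ (symbols introduced here only for illustration). Write a scalarized sample as $\\gamma = m + s z$ with $z \\sim \\mathcal{N}(0, 1)$. Since $\\mathbb{E}|z| = \\sqrt{2 / \\pi}$,\n",
    "\n",
    "$$\n",
    "\\mathbb{E}\\left[ m + \\sqrt{\\frac{\\beta \\pi}{2}} \\, |\\gamma - m| \\right] = m + \\sqrt{\\frac{\\beta \\pi}{2}} \\cdot s \\sqrt{\\frac{2}{\\pi}} = m + \\sqrt{\\beta} \\, s,\n",
    "$$\n",
    "\n",
    "so averaging `ucb_samples` over the sample dimension should recover the analytic UCB value in expectation. For `q > 1`, the `max` over the q-batch dimension is taken inside the expectation, so the value depends on the joint posterior over all `q` points rather than on each point separately."
   ]
  },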
  {
   "cell_type": "markdown",
   "metadata": {
    "originalKey": "7122ce31-f5ee-4fdc-8962-73f692eaaac0",
    "showInput": false
   },
   "source": [
    "#### Ad-hoc testing q-Scalarized-UCB\n",
    "\n",
    "Before hooking the newly defined acquisition function into a Bayesian Optimization loop, we should test it. For this we'll just make sure that it properly evaluates on a compatible multi-output model. Here we just define a basic multi-output `SingleTaskGP` model trained on synthetic data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customOutput": null,
    "executionStartTime": 1668651273753,
    "executionStopTime": 1668651274217,
    "hidden_ranges": [],
    "jupyter": {
     "outputs_hidden": false
    },
    "originalKey": "4958d0f5-cce4-4fc8-8fa2-8b9d7fe2bd5d",
    "requestMsgId": "534c2e0c-f4c0-4bb0-bdec-e44f86c0ed2b"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from botorch.fit import fit_gpytorch_mll\n",
    "from botorch.models import SingleTaskGP\n",
    "from botorch.utils import standardize\n",
    "from gpytorch.mlls import ExactMarginalLogLikelihood\n",
    "\n",
    "torch.set_default_dtype(torch.double)\n",
    "\n",
    "# generate synthetic data\n",
    "X = torch.rand(20, 2)\n",
    "Y = torch.stack([torch.sin(X[:, 0]), torch.cos(X[:, 1])], -1)\n",
    "\n",
    "# construct and fit the multi-output model\n",
    "gp = SingleTaskGP(train_X=X, train_Y=Y)\n",
    "mll = ExactMarginalLogLikelihood(gp.likelihood, gp)\n",
    "fit_gpytorch_mll(mll)\n",
    "\n",
    "# construct the acquisition function\n",
    "qSUCB = qScalarizedUpperConfidenceBound(gp, beta=0.1, weights=torch.tensor([0.1, 0.5]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "customOutput": null,
    "executionStartTime": 1668651274500,
    "executionStopTime": 1668651274569,
    "jupyter": {
     "outputs_hidden": false
    },
    "originalKey": "5070d96c-5673-4bbf-ab72-781d1eabd1bf",
    "requestMsgId": "1a5d87ac-e176-46dd-b7b9-fb29bfd14971"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.5625], grad_fn=<MeanBackward1>)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# evaluate on single q-batch with q=3\n",
    "qSUCB(torch.rand(3, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "customOutput": null,
    "executionStartTime": 1668651274799,
    "executionStopTime": 1668651274876,
    "jupyter": {
     "outputs_hidden": false
    },
    "originalKey": "b3e0a586-b362-4ab4-b4c1-7fc0c99a1be4",
    "requestMsgId": "4f73099f-0dd6-4494-9706-71f50e5dbccf",
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.5165, 0.3577], grad_fn=<MeanBackward1>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# batch-evaluate on two q-batches with q=3\n",
    "qSUCB(torch.rand(2, 3, 2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "originalKey": "aa45db76-36cb-42e9-9386-649d9ea5e741",
    "showInput": false
   },
   "source": [
    "### A scalarized version of analytic UCB (`q=1` only)\n",
    "\n",
    "We can also write an *analytic* version of UCB for a multi-output model, assuming a multivariate normal posterior and `q=1`. The new class `ScalarizedUpperConfidenceBound` subclasses `AnalyticAcquisitionFunction` instead of `MCAcquisitionFunction`. In contrast to the MC version, instead of applying the weights to the MC samples, we directly scalarize the mean vector $\\mu$ and covariance matrix $\\Sigma$ and apply standard UCB to the resulting univariate normal distribution, which has mean $w^T \\mu$ and variance $w^T \\Sigma w$. In addition to the `@t_batch_mode_transform` decorator, here we also pass `expected_q=1` to ensure that the input `X` has `q=1`.\n",
    "\n",
    "*Note:* BoTorch also provides a `ScalarizedPosteriorTransform` abstraction that can be used with any existing analytic acquisition function and automatically performs the scalarization we implement manually below. See the end of this tutorial for a usage example."
   ]
  },
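  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the scalarization concrete, here is a small worked example. The weights $w = (0.1, 0.5)^T$ and $\\beta = 0.1$ match the values used elsewhere in this tutorial, while $\\mu$ and $\\Sigma$ are hypothetical posterior values chosen only for illustration:\n",
    "\n",
    "$$\n",
    "\\mu = \\begin{pmatrix} 1 \\\\ 2 \\end{pmatrix}, \\qquad \\Sigma = \\begin{pmatrix} 0.04 & 0.01 \\\\ 0.01 & 0.09 \\end{pmatrix}.\n",
    "$$\n",
    "\n",
    "Then\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "w^T \\mu &= 0.1 \\cdot 1 + 0.5 \\cdot 2 = 1.1, \\\\\n",
    "w^T \\Sigma w &= w_1^2 \\Sigma_{11} + 2 w_1 w_2 \\Sigma_{12} + w_2^2 \\Sigma_{22} = 0.0004 + 0.001 + 0.0225 = 0.0239, \\\\\n",
    "w^T \\mu + \\sqrt{\\beta \\, w^T \\Sigma w} &= 1.1 + \\sqrt{0.1 \\cdot 0.0239} \\approx 1.149.\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "The cross term $2 w_1 w_2 \\Sigma_{12}$ shows that correlation between the outputs contributes to the exploration bonus; scalarizing the per-output variances alone would miss it."
   ]
  },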
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "customOutput": null,
    "executionStartTime": 1668651275106,
    "executionStopTime": 1668651275190,
    "jupyter": {
     "outputs_hidden": false
    },
    "originalKey": "5528fc1c-5cf1-4cab-9598-9a41e652d6ea",
    "requestMsgId": "22877611-7d10-4e11-8da5-17ae46e362b8"
   },
   "outputs": [],
   "source": [
    "from botorch.acquisition import AnalyticAcquisitionFunction\n",
    "\n",
    "\n",
    "class ScalarizedUpperConfidenceBound(AnalyticAcquisitionFunction):\n",
    "    def __init__(\n",
    "        self,\n",
    "        model: Model,\n",
    "        beta: Tensor,\n",
    "        weights: Tensor,\n",
    "        maximize: bool = True,\n",
    "    ) -> None:\n",
    "        # we use the AcquisitionFunction constructor, since that of\n",
    "        # AnalyticAcquisitionFunction performs some validity checks that we don't want here\n",
    "        super(AnalyticAcquisitionFunction, self).__init__(model)\n",
    "        self.maximize = maximize\n",
    "        self.register_buffer(\"beta\", torch.as_tensor(beta))\n",
    "        self.register_buffer(\"weights\", torch.as_tensor(weights))\n",
    "\n",
    "    @t_batch_mode_transform(expected_q=1)\n",
    "    @average_over_ensemble_models\n",
    "    def forward(self, X: Tensor) -> Tensor:\n",
    "        \"\"\"Evaluate the Upper Confidence Bound on the candidate set `X` using scalarization.\n",
    "\n",
    "        Args:\n",
    "            X: A `(b) x 1 x d`-dim Tensor of `(b)` t-batches with one `d`-dim\n",
    "                design point each.\n",
    "\n",
    "        Returns:\n",
    "            A `(b)`-dim Tensor of Upper Confidence Bound values at the given\n",
    "                design points `X`.\n",
    "        \"\"\"\n",
    "        self.beta = self.beta.to(X)\n",
    "        batch_shape = X.shape[:-2]\n",
    "        posterior = self.model.posterior(X)\n",
    "        means = posterior.mean.squeeze(dim=-2)  # b x o\n",
    "        scalarized_mean = means.matmul(self.weights)  # b\n",
    "        covs = posterior.mvn.covariance_matrix  # b x o x o\n",
    "        weights = self.weights.view(\n",
    "            1, -1, 1\n",
    "        )  # 1 x o x 1 (assume single batch dimension)\n",
    "        weights = weights.expand(batch_shape + weights.shape[1:])  # b x o x 1\n",
    "        weights_transpose = weights.permute(0, 2, 1)  # b x 1 x o\n",
    "        scalarized_variance = torch.bmm(\n",
    "            weights_transpose, torch.bmm(covs, weights)\n",
    "        ).view(\n",
    "            batch_shape\n",
    "        )  # b\n",
    "        delta = (self.beta.expand_as(scalarized_mean) * scalarized_variance).sqrt()\n",
    "        if self.maximize:\n",
    "            return scalarized_mean + delta\n",
    "        else:\n",
    "            return scalarized_mean - delta"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "originalKey": "54811aec-bcad-4d24-8346-fa69c9e1d4c1",
    "showInput": false
   },
   "source": [
    "#### Ad-hoc testing Scalarized-UCB\n",
    "\n",
    "Notice that we pass in an explicit q-batch dimension for consistency, even though `q=1`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "customOutput": null,
    "executionStartTime": 1668651275410,
    "executionStopTime": 1668651275490,
    "jupyter": {
     "outputs_hidden": false
    },
    "originalKey": "f7a98f13-7d87-4e87-afb8-a8ff087faeef",
    "requestMsgId": "99dedb82-13c3-45ea-8126-a2daa706a14a"
   },
   "outputs": [],
   "source": [
    "# construct the acquisition function\n",
    "SUCB = ScalarizedUpperConfidenceBound(gp, beta=0.1, weights=torch.tensor([0.1, 0.5]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "customOutput": null,
    "executionStartTime": 1668651275716,
    "executionStopTime": 1668651275791,
    "jupyter": {
     "outputs_hidden": false
    },
    "originalKey": "78de901c-d537-4b20-8ab2-f1161d12cca3",
    "requestMsgId": "af017973-596f-4bf5-9efe-54062c3c2715"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.4957], grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# evaluate on single point\n",
    "SUCB(torch.rand(1, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false,
    "customOutput": null,
    "executionStartTime": 1668651276011,
    "executionStopTime": 1668651276084,
    "jupyter": {
     "outputs_hidden": false
    },
    "originalKey": "3ffbd8b7-11e6-48d1-8d05-eddd3fc8221c",
    "requestMsgId": "79891745-83b5-459c-bfc7-d1683396b748"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.4183, 0.5188, 0.4453], grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# batch-evaluate on 3 points\n",
    "SUCB(torch.rand(3, 1, 2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "code_folding": [],
    "hidden_ranges": [],
    "originalKey": "b9127c73-b1b6-4e46-b593-43068d5015a8",
    "showInput": false
   },
   "source": [
    "### Appendix: Using `ScalarizedPosteriorTransform`\n",
    "\n",
    "Using the `ScalarizedPosteriorTransform` abstraction, the functionality of `ScalarizedUpperConfidenceBound` implemented above can be easily achieved in just a few lines of code. `PosteriorTransform`s can be used with both the MC and analytic acquisition functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customOutput": null,
    "executionStartTime": 1668651301431,
    "executionStopTime": 1668651301442,
    "hidden_ranges": [],
    "jupyter": {
     "outputs_hidden": false
    },
    "originalKey": "79a4b36f-4b14-4a62-9dc6-883931ceb5d3",
    "requestMsgId": "490a2b3a-24f9-4b49-84f1-7a3851d5944f"
   },
   "outputs": [],
   "source": [
    "from botorch.acquisition.objective import ScalarizedPosteriorTransform\n",
    "from botorch.acquisition.analytic import UpperConfidenceBound\n",
    "\n",
    "pt = ScalarizedPosteriorTransform(weights=torch.tensor([0.1, 0.5]))\n",
    "SUCB = UpperConfidenceBound(gp, beta=0.1, posterior_transform=pt)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
